{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "pendulum_collect_data [action_gap].ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      }
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bZCvVbB1gbGH",
        "colab_type": "text"
      },
      "source": [
        "***Copyright 2020 Google LLC.***\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "X_Kwbd8GgW6k",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Default title text\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GXU9qpMeJwEo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import gym\n",
        "import tensorflow.compat.v2 as tf\n",
        "import numpy as np\n",
        "import pickle\n",
        "import imp\n",
        "import getpass\n",
        "import os\n",
        "import random\n",
        "import string\n",
        "\n",
        "from action_gap_rl import replay\n",
        "from action_gap_rl import value as value_lib\n",
        "from action_gap_rl.policies import layers_lib\n",
        "replay = imp.reload(replay)\n",
        "value_lib = imp.reload(value_lib)\n",
        "layers_lib = imp.reload(layers_lib)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wttrwVVVJ4z6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "tf.enable_v2_behavior()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nY0rGKdMJ6hL",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class AttrDict(dict):\n",
        "  def __init__(self, *args, **kwargs):\n",
        "    super(AttrDict, self).__init__(*args, **kwargs)\n",
        "    self.__dict__ = self\n",
        "\n",
        "def to_dict(d):\n",
        "  if isinstance(d, AttrDict):\n",
        "    return {k: to_dict(v) for k, v in d.items()}\n",
        "  return d\n",
        "\n",
        "def filter_bool(lst, mask):\n",
        "  return [lst[i] for i in range(len(lst)) if mask[i]]\n",
        "\n",
        "def rand_str(N):\n",
        "  return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(N))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e6pQzBuRJ9qH",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class BehaviorPolicy(tf.keras.Model):\n",
        "  \"\"\"A policy that takes an arbitrary function as the un-normalized log pdf.\"\"\"\n",
        "\n",
        "  def __init__(self, config, name=None):\n",
        "    super(BehaviorPolicy, self).__init__(\n",
        "        name=name or self.__class__.__name__)\n",
        "    self._config = config\n",
        "    self.num_actions = config.num_actions\n",
        "    if 'initializer' in config:\n",
        "      init = config.initializer\n",
        "    else:\n",
        "      init = tf.keras.initializers.glorot_uniform()\n",
        "    hidden_widths = config.hidden_widths\n",
        "    if config.embed:\n",
        "      transformation_layers = [layers_lib.soft_hot_layer(**config.embed)]\n",
        "    else:\n",
        "      transformation_layers = []\n",
        "    self._body = tf.keras.Sequential(\n",
        "        transformation_layers\n",
        "        + [tf.keras.layers.Dense(w, activation='relu', kernel_initializer=init) for w in hidden_widths]\n",
        "        + [tf.keras.layers.Dense(self.num_actions, activation=None, kernel_initializer=init)]\n",
        "    )\n",
        "\n",
        "  def call(self, states):\n",
        "    return tf.argmax(self._body(tf.expand_dims(states, axis=0)), axis=-1).numpy()[0]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yZv-4FjTixf1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class ActionAdaptor(object):\n",
        "\n",
        "  def __init__(self, env, actions={0:-2., 1:2.}, t_res=1):\n",
        "    self.env = env\n",
        "    self.actions = actions\n",
        "    self.t_res = t_res\n",
        "    assert t_res >= 1\n",
        "\n",
        "  def step(self, a):\n",
        "    for _ in range(self.t_res):\n",
        "      result = self.env.step([self.actions[a]])\n",
        "    return result\n",
        "\n",
        "  def reset(self):\n",
        "    return self.env.reset()\n",
        "\n",
        "  @property\n",
        "  def unwrapped(self):\n",
        "    return self.env.unwrapped\n",
        "\n",
        "  @property\n",
        "  def action_space(self):\n",
        "    return gym.spaces.Discrete(2)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GMZs6RiWk0ta",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import copy\n",
        "\n",
        "def policy_returns_with_horizon(env, state, policy, horizon, irresolution=1, forced_actions=()):\n",
        "  env = copy.deepcopy(env)\n",
        "  R = 0.\n",
        "  for t in range(horizon):\n",
        "    if t < len(forced_actions):\n",
        "      a = forced_actions[t]\n",
        "    else:\n",
        "      a = policy(state)\n",
        "    for _ in range(irresolution):\n",
        "      state, reward, term, _ = env.step(a)\n",
        "      R += reward\n",
        "      if term: break\n",
        "  return R"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bkm8jwCmfHvw",
        "colab_type": "code",
        "outputId": "9da83c76-eacb-4b75-9d97-e50801a9c7c1",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1582149806054,
          "user_tz": 480,
          "elapsed": 58436,
          "user": {
            "displayName": "Dan Abolafia",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAQmxyT8biPkQeDgz5lf2ocSzJsOqH8BqBuLz1a=s64",
            "userId": "05741471333872541970"
          }
        },
        "colab": {
          "height": 452
        }
      },
      "source": [
        "# TODO: Horizon H returns under optimal and behavior policies.\n",
        "# TODO: more efficient episode sampling using an ensemble of behavior policies\n",
        "\n",
        "WRITE_OUT = True  #@param\n",
        "FILTER = True  #@param\n",
        "compute_behavior_policy_returns = True  #@param\n",
        "## compute_optimal_policy_returns = False  #@param\n",
        "num_episodes = 30  #@param\n",
        "num_datasets = 2  #@param\n",
        "episode_length = 200  #@param\n",
        "temporal_resolution = 10  #@param\n",
        "horizons = [1, 5, 10]  #@param\n",
        "# file_name = \"v3/pendulum_a2_t10_nnp_eval\"  #@param\n",
        "file_name = \"v3/pendulum_test\"  #@param\n",
        "RENDER = False  #@param\n",
        "\n",
        "env = ActionAdaptor(gym.make('Pendulum-v0'))\n",
        "\n",
        "\n",
        "embed=layers_lib.obs_embedding_kwargs(\n",
        "    20,\n",
        "    bounds=((-1,1),(-1,1),(0,2*np.pi)),\n",
        "    variance=[1.]*3,\n",
        "    spillover=0.05,\n",
        ")\n",
        "# embed=None\n",
        "\n",
        "data_keys = []\n",
        "if compute_behavior_policy_returns:\n",
        "  for h in horizons:\n",
        "    data_keys.extend(['pi0_h={}/R0'.format(h), 'pi0_h={}/R1'.format(h)])\n",
        "memory = replay.Memory(data_keys)\n",
        "\n",
        "for dataset_index in range(num_datasets):\n",
        "  print('dataset index =', dataset_index)\n",
        "  for _ in range(num_episodes):\n",
        "    behavior_policy = BehaviorPolicy(AttrDict(\n",
        "        num_actions=2,\n",
        "        initializer=tf.keras.initializers.glorot_normal(),\n",
        "        embed=embed,\n",
        "        hidden_widths=[64]),\n",
        "        name='policy_'+rand_str(10))\n",
        "\n",
        "    # collect a trajectory\n",
        "    obs = env.reset()\n",
        "    memory.log_init(obs)\n",
        "\n",
        "    for _ in range(episode_length // temporal_resolution):\n",
        "      if RENDER: env.render()\n",
        "      act = behavior_policy(obs)\n",
        "      for rep in range(temporal_resolution):\n",
        "        if compute_behavior_policy_returns:\n",
        "          data = {}\n",
        "          for h in horizons:\n",
        "            if rep % h == 0:\n",
        "              r0, r1 = [\n",
        "                    policy_returns_with_horizon(\n",
        "                        env,\n",
        "                        obs,\n",
        "                        behavior_policy,\n",
        "                        horizon=horizon,\n",
        "                        irresolution=temporal_resolution,\n",
        "                        forced_actions=(a,)) \n",
        "                    for a in (0, 1)]\n",
        "              data.update({'pi0_h={}/R0'.format(h): r0, 'pi0_h={}/R1'.format(h): r1})\n",
        "            else:\n",
        "              data.update({'pi0_h={}/R0'.format(h): 0., 'pi0_h={}/R1'.format(h): 0.})\n",
        "        else:\n",
        "          data = {}\n",
        "        next_obs, reward, term, _ = env.step(act)\n",
        "        memory.log_experience(obs, act, reward, next_obs, data=data)\n",
        "        obs = next_obs\n",
        "        if term:\n",
        "          break\n",
        "    if RENDER: env.render()\n",
        "\n",
        "  print('done simulating')\n",
        "\n",
        "  if FILTER:\n",
        "    ma = np.mean(memory.actions, axis=1)\n",
        "    mask = np.logical_and(ma>=0.33, ma<=.66)\n",
        "    print('Num episodes retained:', np.count_nonzero(mask))\n",
        "    print('Returns:', np.sum(memory.rewards, axis=1)[mask].tolist())\n",
        "    memory.observations = filter_bool(memory.observations, mask)\n",
        "    memory.actions = filter_bool(memory.actions, mask)\n",
        "    memory.rewards = filter_bool(memory.rewards, mask)\n",
        "    print('done filtering')\n",
        "\n",
        "  if WRITE_OUT:\n",
        "    s = memory.serialize()\n",
        "\n",
        "    # Make directory.\n",
        "    user = getpass.getuser()\n",
        "    path = '/tmp/action_gap_rl/datasets'.format(user)\n",
        "    os.makedirs(path)\n",
        "\n",
        "    # Save pickle file\n",
        "    with open(\n",
        "        os.path.join(path, '{}.{}.pickle'.format(file_name, dataset_index)),\n",
        "        'wb') as f:\n",
        "      f.write(s)\n",
        "\n",
        "    # Sanity check serialization.\n",
        "    m2 = replay.Memory()\n",
        "    m2.unserialize(s)\n",
        "    print(np.array_equal(m2.entered_states(), memory.entered_states()))\n",
        "    print(np.array_equal(m2.exited_states(), memory.exited_states()))\n",
        "    print(np.array_equal(m2.attempted_actions(), memory.attempted_actions()))\n",
        "    print(np.array_equal(m2.observed_rewards(), memory.observed_rewards()))\n",
        "  \n",
        "  print('\\n\\n')\n",
        "\n"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "dataset index = 0\n",
            "done simulating\n",
            "Num episodes retained: 6\n",
            "Returns: [-1589.4027487638214, -1525.0909899617527, -1285.8440169927087, -1391.7481653357436, -1590.4309787361758, -963.4545953018495]\n",
            "done filtering\n",
            "True\n",
            "True\n",
            "True\n",
            "True\n",
            "\n",
            "\n",
            "\n",
            "dataset index = 1\n",
            "done simulating\n",
            "Num episodes retained: 10\n",
            "Returns: [-1589.4027487638214, -1525.0909899617527, -1285.8440169927087, -1391.7481653357436, -1590.4309787361758, -963.4545953018495, -1619.7998362736216, -1394.871135519214, -1579.172729523288, -1582.0122921012794]\n",
            "done filtering\n",
            "True\n",
            "True\n",
            "True\n",
            "True\n",
            "\n",
            "\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AOmsWcMjbKym",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}